# A PyTorch-based implementation of LISV2.
# author: ***
# email: ***

import pdb
from locale import currency
# from pso import generate
# import Queue

import numpy as np
import torch
import torch.autograd
from torch.functional import split
import torch.nn.functional as F
from scipy import optimize
# from scipy.spatial.distance import distance, squareform
from sklearn import datasets, manifold
from sklearn.manifold import SpectralEmbedding
from sklearn.metrics.pairwise import euclidean_distances, pairwise_distances
from torch import nn, set_flush_denormal
from torch.autograd import Variable
from torch.nn.modules import loss
import functools
from multiprocessing import Pool


class LISV2_MLP(nn.Module):
    def __init__(
        self,
        data,
        device,
        args,
        path,
        n_dim=2,
    ):
        """Set up hyper-parameters, the input similarity matrix and the networks.

        Args:
            data: 2-D tensor; ``data.shape[0]`` is stored as the number of
                points and ``data.shape[1]`` becomes the input layer width.
            device: device on which ``self.P`` and ``self.data`` are kept.
            args: dict read for the keys 'perplexity', 'NetworkStructure',
                'data_name' and 'vtrace' (helpers also read 'pow').
                NOTE: ``args['NetworkStructure'][0]`` is overwritten in place.
            path: printed, and forwarded to ``Distance_squared`` as ``savepath``.
            n_dim: stored as ``self.n_dim``.
        """
        super(LISV2_MLP, self).__init__()
        self.ReconstructionLoss = nn.MSELoss()

        with torch.no_grad():
            self.mseLossFunc = nn.MSELoss()
            self.device = device
            self.phis = []
            self.epoch = 0
            self.n_dim = n_dim
            self.n_points = data.shape[0]

            # Hyper-parameters; the first layer width is forced to match data.
            self.args = args
            self.perplexity = args['perplexity']
            structure = self.args['NetworkStructure']
            structure[0] = data.shape[1]
            self.NetworkStructure = structure
            self.dataset_name = args['data_name']

            # v = 100 for the input space, v = 1 for every later layer.
            self.vList = [100] + [1] * (len(structure) - 1)
            self.gammaList = self.CalGammaF(self.vList)

            print('start to claculate sigma')
            dist = self.Distance_squared(data, data, savepath=path).float()
            print(path)
            rho, self.sigmaListlayer = self.InitSigmaSearchEverySample(
                self.gammaList, self.vList, data, dist)
            input_sigma = self.sigmaListlayer[0].cpu()
            similarity = self.CalPt(dist,
                                    rho,
                                    input_sigma,
                                    gamma=self.gammaList[0],
                                    v=self.vList[0],
                                    split=1)
            self.P = similarity.float().detach().to(self.device)

            # Per-epoch schedule for v: 1000 epochs at vtrace[0], 2000 epochs
            # log-spaced towards vtrace[1], then 17001 epochs held at vtrace[1].
            log_start = np.log10(self.args['vtrace'][0])
            log_end = np.log10(self.args['vtrace'][1])
            self.vListForEpoch = np.concatenate([
                np.full((1000, ), 10**log_start),
                np.logspace(log_start, log_end, 2000),
                np.full((17001, ), 10**log_end),
            ])

            print('start Init network')
            self.InitNetwork()
            # The spectral-embedding initialisation is currently disabled;
            # only its progress message is still emitted.
            print('start SpectralEmbedding')
            print('init LISV2_MLP mocol model, gaama is: {}'.format(
                self.gammaList))
            self.data = data.float().to(self.device)
            torch.cuda.empty_cache()

    def InitSigmaSearchEverySample(
        self,
        gammaList,
        vList,
        data,
        dist,
    ):
        """Compute the per-sample offset ``rho`` and the per-layer sigma tensors.

        Args:
            gammaList: only ``gammaList[0]`` is used (passed to the search).
            vList: only ``vList[0]`` is used (passed to the search).
            data: tensor whose shape gives the sample count and feature count.
            dist: pairwise distance tensor, one row per sample.

        Returns:
            ``(rho, sigmaListlayer)`` where ``rho`` is the row-wise minimum of
            ``dist`` ignoring entries <= 1e-11, and ``sigmaListlayer`` holds one
            entry per element of ``args['NetworkStructure']``.
        """
        # Entries <= 1e-11 (such as a diagonal pinned to 1e-12) are pushed to
        # 1e16 so that the row minimum skips them.
        masked_dist = dist.masked_fill(dist.le(1e-11), 1e16)
        rho = masked_dist.min(dim=1).values

        print('start pool search')
        n_layers = len(self.args['NetworkStructure'])
        sigmaListlayer = [0] * n_layers

        runner = PoolRunner(self.n_points,
                            self.perplexity,
                            dist.detach().cpu().numpy(),
                            rho.detach().cpu().numpy(),
                            gammaList[0],
                            vList[0],
                            pow=self.args['pow'])
        sigmaListlayer[0] = torch.tensor(runner.Getout()).to(self.device)

        std_dis = torch.std(rho) / np.sqrt(data.shape[1])
        print('std', std_dis)

        n_samples = data.shape[0]
        if std_dis > 0.2:
            # Keep the searched sigmas for layer 0; later layers use sigma = 1.
            for layer in range(1, n_layers):
                sigmaListlayer[layer] = torch.ones(n_samples,
                                                   device=self.device)
        else:
            # Every layer (including layer 0) gets a constant sigma.
            # NOTE(review): layer 0 is overwritten first and then re-read on
            # each later iteration, so the constant grows by a factor of 5 per
            # layer. Confirm whether that compounding is intended.
            for layer in range(n_layers):
                fill_value = sigmaListlayer[0].mean() * 5
                sigmaListlayer[layer] = torch.zeros(
                    n_samples, device=self.device) + fill_value
        return rho, sigmaListlayer

    def InitNetwork(self, ):
        self.encoder = nn.ModuleList()
        for i in range(len(self.NetworkStructure) - 1):
            self.encoder.append(
                nn.Linear(self.NetworkStructure[i],
                          self.NetworkStructure[i + 1]))
            if i != len(self.NetworkStructure) - 2:
                self.encoder.append(nn.LeakyReLU(0.1))

        self.decoder = nn.ModuleList()
        for i in range(len(self.NetworkStructure) - 1, 0, -1):
            self.decoder.append(
                nn.Linear(self.NetworkStructure[i],
                          self.NetworkStructure[i - 1]))
            if i != 1:
                self.decoder.append(nn.LeakyReLU(0.1))
        # Map output to range (0, 1) for image datasets
        if('mnist' in self.dataset_name or 'coil' in self.dataset_name):
            self.decoder.append(nn.Sigmoid())
    def CalGammaF(self, vList):
        import scipy
        out = []
        for v in vList:
            a = scipy.special.gamma((v + 1) / 2)
            b = np.sqrt(v * np.pi) * scipy.special.gamma(v / 2)
            out.append(a / b)

        return out

    def GetInput(self, ):
        """Placeholder that always returns None."""
        return

    def CalPt(self, dist, rho, sigma_array, gamma, v=100, split=1):

        if torch.is_tensor(rho):
            dist_rho = (dist - rho.reshape(-1, 1)) / sigma_array.reshape(-1, 1)
        else:
            dist_rho = dist
        dist_rho[dist_rho < 0] = 0
        # print('pass1')
        sample_index_list = torch.linspace(0, dist.shape[0], int(split) + 1)
        # print('pass2')
        for i in range(split):
            # print(i)
            dist_rho_c = dist_rho[int(sample_index_list[i]
                                      ):int(sample_index_list[i + 1])]
            Pij_c = torch.pow(
                gamma * torch.pow((1 + dist_rho_c / v), -1 * (v + 1) / 2) * \
                torch.sqrt(torch.tensor(2 * 3.14)), self.args['pow'])
            if i == 0:
                Pij = Pij_c
            else:
                Pij = torch.cat([Pij, Pij_c], dim=0)
        P = Pij + Pij.t() - torch.mul(Pij, Pij.t())

        return P

    def Distance_squared(
        self,
        x,
        y,
        savepath=None,
    ):
        m, n = x.size(0), y.size(0)
        xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
        yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
        dist = xx + yy
        dist.addmm_(1, -2, x, y.t())
        d = dist.clamp(min=1e-12)
        d[torch.eye(d.shape[0]) == 1] = 1e-12
        # if savepath:
        #     np.save(savepath+'dist.npy', d.detach().cpu().numpy(), )
        
        return d

    def CE(self, P, Q):

        EPS = 1e-12
        # Q=(P+Q)/2
        losssum1 = (P * torch.log(Q + EPS)).mean()
        losssum2 = ((1-P) * torch.log(1-Q + EPS)).mean()
        losssum = -1*(losssum1 + losssum2)
        
        if torch.isnan(losssum):
            input('stop and find nan')
        return losssum

    def Loss(self, latentList, input_data_index):

        loss_ce = self.CE(P=self.P[input_data_index][:, input_data_index],
                          Q=self.CalPt(dist=self.Distance_squared(
                              latentList[0], latentList[0]),
                                       rho=0,
                                       sigma_array=1,
                                       gamma=self.gammaList[-1],
                                       v=self.vList[-1]))
        loss_rc = self.ReconstructionLoss(
            self.Generate(latentList[0].detach())[0],
            self.data[input_data_index])
        return [loss_ce, loss_rc / 10]

    def ChangeVList(self):

        epoch = self.epoch
        self.vCurent = self.vListForEpoch[epoch]
        newVList = [100]
        for i in range(len(self.args['NetworkStructure']) - 1):
            newVList.append(self.vCurent)
        self.vList = newVList
        self.gammaList = self.CalGammaF(newVList)

    def forward(self, input_data_index):

        self.ChangeVList()
        x = self.data[input_data_index]

        for i, layer in enumerate(self.encoder):
            x = layer(x)

        return [x]

    def Generate(self, latent):

        x = latent
        for i, layer in enumerate(self.decoder):
            x = layer(x)

        return [x]

    def test(self, input_data):

        self.ChangeVList()
        x = input_data.to(self.device)

        for i, layer in enumerate(self.encoder):
            x = layer(x)

        return [x]


class PoolRunner(object):
    """Run ``sigma_binary_search`` for every sample in a process pool.

    The search is executed eagerly in the constructor; the per-sample sigmas
    are then available through ``Getout``.
    """

    def __init__(self,
                 n,
                 N_NEIGHBOR,
                 dist,
                 rho,
                 gamma,
                 v,
                 pow=2,
                 processes=30):
        """Search one sigma per row of ``dist``.

        Args:
            n: number of rows of ``dist`` / entries of ``rho`` to process.
            N_NEIGHBOR: target value handed to ``sigma_binary_search``.
            dist: indexable by row; ``dist[i]`` is sample i's distance row.
            rho: indexable; ``rho[i]`` is sample i's offset.
            gamma: kernel constant forwarded to the search.
            v: kernel parameter forwarded to the search.
            pow: kernel exponent forwarded to the search.
            processes: number of worker processes (previously fixed at 30).
        """
        print(n)
        tasks = [(N_NEIGHBOR, dist[row], rho[row], gamma, v, pow)
                 for row in range(n)]

        print('start calculate sigma')
        # The context manager guarantees the workers are released even when a
        # task raises; starmap blocks until all results are in and keeps the
        # input order.
        with Pool(processes=processes) as pool:
            sigma_array = pool.starmap(sigma_binary_search, tasks)

        self.sigma_array = np.array(sigma_array)
        print("\nMean sigma = " + str(np.mean(sigma_array)))
        print('finish calculate sigma')

    def Getout(self, ):
        """Return the array of per-sample sigmas found by the search."""
        return self.sigma_array


def sigma_binary_search(fixed_k, dist_row_line, rho_line, gamma, v, pow=2):
    """
    Bisect for the sigma at which ``func(sigma, ...)`` equals ``fixed_k``.

    The search interval starts as (0, 100) and is halved at most 20 times.
    The lower bound is raised whenever the evaluated value is below
    ``fixed_k``, i.e. ``func`` is treated as increasing in sigma. The loop
    stops early once the value is within 1e-4 of ``fixed_k``.

    Returns the last sigma that was evaluated.
    """
    low, high = 0, 100
    remaining = 20
    while remaining > 0:
        remaining -= 1
        approx_sigma = (low + high) / 2
        k_value = func(approx_sigma,
                       dist_row_line,
                       rho_line,
                       gamma,
                       v,
                       pow=pow)
        # Close enough: the bounds no longer matter, so stop right away.
        if np.abs(fixed_k - k_value) <= 1e-4:
            break
        if k_value < fixed_k:
            low = approx_sigma
        else:
            high = approx_sigma
    return approx_sigma


def func(sigma, dist_row_line, rho_line, gamma, v, pow=2):
    """
    Evaluate ``2 ** sum(p)`` for one row of distances, where

        d = max((dist_row_line - rho_line) / sigma, 0)
        p = (gamma * (1 + d / v) ** (-(v + 1) / 2) * sqrt(2 * 3.14)) ** pow

    ``dist_row_line`` is not modified.
    """
    shifted = (dist_row_line - rho_line) / sigma
    clipped = np.maximum(shifted, 0)

    exponent = -1 * (v + 1) / 2
    kernel = gamma * np.power((1 + clipped / v), exponent) * np.sqrt(2 * 3.14)
    total = np.sum(np.power(kernel, pow))
    return np.power(2, total)
